#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
模型下载工具

使用方法:
python download_models.py [模型名称]

支持的模型:
- u2net: U2-Net 完整模型 (约176MB)
- u2netp: U2-Net 轻量版模型 (约4.7MB)
- all: 下载所有可用模型
"""

import os
import sys
import argparse
from utils.model_downloader import download_model_if_needed

def main():
    parser = argparse.ArgumentParser(description="下载预训练模型工具")
    parser.add_argument("model", nargs="?", default="all", 
                        help="要下载的模型名称 (u2net, u2netp, all)")
    
    args = parser.parse_args()
    model_name = args.model.lower()
    
    if model_name == "all":
        print("开始下载所有模型...")
        models = ["u2net", "u2netp"]
        for model in models:
            download_model_if_needed(model)
    elif model_name in ["u2net", "u2netp"]:
        download_model_if_needed(model_name)
    else:
        print(f"不支持的模型: {model_name}")
        print("支持的模型: u2net, u2netp, all")
        return 1
    
    return 0

if __name__ == "__main__":
    sys.exit(main()) 